# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Replay API client."""

import base64
import copy
import contextlib
import enum
import inspect
import io
import json
import os
import re
from typing import Any, Literal, Optional, Union, Iterator, AsyncIterator

import google.auth

from . import errors
from ._api_client import BaseApiClient
from ._api_client import HttpRequest
from ._api_client import HttpResponse
from ._common import BaseModel
from .types import HttpOptions, HttpOptionsOrDict


def to_snake_case(name: str) -> str:
  """Converts a string from camelCase or PascalCase to snake_case."""

  if not isinstance(name, str):
    name = str(name)
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def _normalize_json_case(obj: Any) -> Any:
  if isinstance(obj, dict):
    return {
        to_snake_case(k): _normalize_json_case(v)
        for k, v in obj.items()
    }
  elif isinstance(obj, list):
    return [_normalize_json_case(item) for item in obj]
  elif isinstance(obj, enum.Enum):
    return obj.value
  elif isinstance(obj, str):
    # Python >= 3.14 has a new division by zero error message.
    if 'division by zero' in obj:
      return obj.replace(
          'division by zero', 'integer division or modulo by zero'
      )
  return obj


def _equals_ignore_key_case(obj1: Any, obj2: Any) -> bool:
  """Compares two Python objects for equality ignoring key casing.

  Returns:
      bool: True if the two objects are equal regardless of key casing
  (camelCase vs. snake_case). For example, the following are considered equal:

  {'my_key': 'my_value'}
  {'myKey': 'my_value'}

  This also considers enums and strings with the same value as equal.
  For example, the following are considered equal:

  {'type': <Type.STRING: 'STRING'>}}
  {'type': 'STRING'}
  """

  normalized_obj_1 = _normalize_json_case(obj1)
  normalized_obj_2 = _normalize_json_case(obj2)

  if normalized_obj_1 == normalized_obj_2:
    return True
  else:
    return False


def _redact_version_numbers(version_string: str) -> str:
  """Redacts version numbers in the form x.y.z from a string."""
  return re.sub(r'\d+\.\d+\.\d+[a-zA-Z0-9]*', '{VERSION_NUMBER}', version_string)


def _redact_language_label(language_label: str) -> str:
  """Removed because replay requests are used for all languages."""
  return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)


def _redact_request_headers(headers: dict[str, str]) -> dict[str, str]:
  """Redacts headers that should not be recorded."""
  redacted_headers = {}
  for header_name, header_value in headers.items():
    if header_name.lower() == 'x-goog-api-key':
      redacted_headers[header_name] = '{REDACTED}'
    elif header_name.lower() == 'user-agent':
      redacted_headers[header_name] = _redact_language_label(
          _redact_version_numbers(header_value)
      )
    elif header_name.lower() == 'x-goog-api-client':
      redacted_headers[header_name] = _redact_language_label(
          _redact_version_numbers(header_value)
      )
    elif header_name.lower() == 'x-goog-user-project':
      continue
    elif header_name.lower() == 'authorization':
      continue
    else:
      redacted_headers[header_name] = header_value
  return redacted_headers


def _redact_request_url(url: str) -> str:
  # Redact all the url parts before the resource name, so the test can work
  # against any project, location, version, or whether it's EasyGCP.
  result = re.sub(
      r'.*/projects/[^/]+/locations/[^/]+/',
      '{VERTEX_URL_PREFIX}/',
      url,
  )
  result = re.sub(
      r'.*-aiplatform.googleapis.com/[^/]+/',
      '{VERTEX_URL_PREFIX}/',
      result,
  )
  result = re.sub(
      r'.*aiplatform.googleapis.com/[^/]+/',
      '{VERTEX_URL_PREFIX}/',
      result,
  )
  result = re.sub(
      r'.*generativelanguage.*.googleapis.com/[^/]+',
      '{MLDEV_URL_PREFIX}',
      result,
  )
  return result


def _redact_project_location_path(path: str) -> str:
  # Redact a field in the request that is known to vary based on project and
  # location.
  if 'projects/' in path and 'locations/' in path:
    result = re.sub(
        r'projects/[^/]+/locations/[^/]+/',
        '{PROJECT_AND_LOCATION_PATH}/',
        path,
    )
    return result
  else:
    return path


def _redact_request_body(body: dict[str, object]) -> None:
  """Redacts fields in the request body in place."""
  for key, value in body.items():
    if isinstance(value, str):
      body[key] = _redact_project_location_path(value)


def redact_http_request(http_request: HttpRequest) -> None:
  http_request.headers = _redact_request_headers(http_request.headers)
  http_request.url = _redact_request_url(http_request.url)
  if not isinstance(http_request.data, bytes):
    _redact_request_body(http_request.data)


def _current_file_path_and_line() -> str:
  """Prints the current file path and line number."""
  current_frame = inspect.currentframe()
  if (
      current_frame is not None
      and current_frame.f_back is not None
      and current_frame.f_back.f_back is not None
  ):
    frame = current_frame.f_back.f_back
    filepath = inspect.getfile(frame)
    lineno = frame.f_lineno
    return f'File: {filepath}, Line: {lineno}'
  return ''


def _debug_print(message: str) -> None:
  print(
      'DEBUG (test',
      os.environ.get('PYTEST_CURRENT_TEST'),
      ')',
      _current_file_path_and_line(),
      ':\n    ',
      message,
  )


def pop_undeterministic_headers(headers: dict[str, str]) -> None:
  """Remove headers that are not deterministic."""
  headers.pop('Date', None)  # pytype: disable=attribute-error
  headers.pop('Server-Timing', None)  # pytype: disable=attribute-error


@contextlib.contextmanager
def _record_on_api_error(client: 'ReplayApiClient', http_request: HttpRequest) -> Iterator[None]:
  try:
    yield
  except errors.APIError as e:
    client._record_interaction(http_request, e)
    raise e

@contextlib.asynccontextmanager
async def _async_record_on_api_error(client: 'ReplayApiClient', http_request: HttpRequest) -> AsyncIterator[None]:
  try:
    yield
  except errors.APIError as e:
    client._record_interaction(http_request, e)
    raise e

class ReplayRequest(BaseModel):
  """Represents a single request in a replay."""

  method: str
  url: str
  headers: dict[str, str]
  body_segments: list[dict[str, object]]


class ReplayResponse(BaseModel):
  """Represents a single response in a replay."""

  status_code: int = 200
  headers: dict[str, str]
  body_segments: list[dict[str, object]]
  byte_segments: Optional[list[bytes]] = None
  sdk_response_segments: list[dict[str, object]]

  def model_post_init(self, __context: Any) -> None:
    pop_undeterministic_headers(self.headers)


class ReplayInteraction(BaseModel):
  """Represents a single interaction, request and response in a replay."""

  request: ReplayRequest
  response: ReplayResponse


class ReplayFile(BaseModel):
  """Represents a recorded session."""

  replay_id: str
  interactions: list[ReplayInteraction]


class ReplayApiClient(BaseApiClient):
  """For integration testing, send recorded response or records a response."""

  def __init__(
      self,
      mode: Literal['record', 'replay', 'auto', 'api'],
      replay_id: str,
      replays_directory: Optional[str] = None,
      vertexai: bool = False,
      api_key: Optional[str] = None,
      credentials: Optional[google.auth.credentials.Credentials] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
      http_options: Optional[HttpOptions] = None,
      private: bool = False,
  ):
    super().__init__(
        vertexai=vertexai,
        api_key=api_key,
        credentials=credentials,
        project=project,
        location=location,
        http_options=http_options,
    )
    self.replays_directory = replays_directory
    if not self.replays_directory:
      self.replays_directory = os.environ.get(
          'GOOGLE_GENAI_REPLAYS_DIRECTORY', None
      )
    # Valid replay modes are replay-only or record-and-replay.
    self.replay_session: Union[ReplayFile, None] = None
    self._mode = mode
    self._replay_id = replay_id
    self._private = private

  def initialize_replay_session(self, replay_id: str) -> None:
    self._replay_id = replay_id
    self._initialize_replay_session()

  def _get_replay_file_path(self) -> str:
    return self._generate_file_path_from_replay_id(
        self.replays_directory, self._replay_id
    )

  def _should_call_api(self) -> bool:
    return self._mode in ['record', 'api'] or (
        self._mode == 'auto'
        and not os.path.isfile(self._get_replay_file_path())
    )

  def _should_update_replay(self) -> bool:
    return self._should_call_api() and self._mode != 'api'

  def _initialize_replay_session_if_not_loaded(self) -> None:
    if not self.replay_session:
      self._initialize_replay_session()

  def _initialize_replay_session(self) -> None:
    _debug_print('Test is using replay id: ' + self._replay_id)
    self._replay_index = 0
    self._sdk_response_index = 0
    replay_file_path = self._get_replay_file_path()
    # This should not be triggered from the constructor.
    replay_file_exists = os.path.isfile(replay_file_path)
    if self._mode == 'replay' and not replay_file_exists:
      raise ValueError(
          'Replay files do not exist for replay id: ' + self._replay_id
      )

    if self._mode in ['replay', 'auto'] and replay_file_exists:
      with open(replay_file_path, 'r') as f:
        self.replay_session = ReplayFile.model_validate(json.loads(f.read()))

    if self._should_update_replay():
      self.replay_session = ReplayFile(
          replay_id=self._replay_id, interactions=[]
      )

  def _generate_file_path_from_replay_id(self, replay_directory: Optional[str], replay_id: str) -> str:
    session_parts = replay_id.split('/')
    if len(session_parts) < 3:
      raise ValueError(
          f'{replay_id}: Session ID must be in the format of'
          ' module/function/[vertex|mldev]'
      )
    if replay_directory is None:
      path_parts = []
    else:
      path_parts = [replay_directory]
    path_parts.extend(session_parts)
    return os.path.join(*path_parts) + '.json'

  def close(self) -> None:
    if not self._should_update_replay() or not self.replay_session:
      return
    replay_file_path = self._get_replay_file_path()
    os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
    with open(replay_file_path, 'w') as f:
      f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
    self.replay_session = None

  def _record_interaction(
      self,
      http_request: HttpRequest,
      http_response: Union[HttpResponse, errors.APIError, bytes],
  ) -> None:
    if not self._should_update_replay():
      return
    redact_http_request(http_request)
    request = ReplayRequest(
        method=http_request.method,
        url=http_request.url,
        headers=http_request.headers,
        body_segments=[http_request.data],
    )
    if isinstance(http_response, HttpResponse):
      response = ReplayResponse(
          headers=dict(http_response.headers),
          body_segments=list(http_response.segments()),
          byte_segments=[
              seg[:100] + b'...' for seg in http_response.byte_segments()
          ],
          status_code=http_response.status_code,
          sdk_response_segments=[],
      )
    elif isinstance(http_response, errors.APIError):
      response = ReplayResponse(
          headers=dict(http_response.response.headers),
          body_segments=[http_response._to_replay_record()],
          status_code=http_response.code,
          sdk_response_segments=[],
      )
    elif isinstance(http_response, bytes):
      response = ReplayResponse(
          headers={},
          body_segments=[],
          byte_segments=[http_response],
          sdk_response_segments=[],
      )
    else:
      raise ValueError(
          'Unsupported http_response type: ' + str(type(http_response))
      )
    if self.replay_session is None:
      raise ValueError('No replay session found.')
    self.replay_session.interactions.append(
        ReplayInteraction(request=request, response=response)
    )

  def _match_request(
      self,
      http_request: HttpRequest,
      interaction: ReplayInteraction,
  ) -> None:
    _debug_print(f'http_request.url: {http_request.url}')
    _debug_print(f'interaction.request.url: {interaction.request.url}')
    assert http_request.url == interaction.request.url
    assert http_request.headers == interaction.request.headers, (
        'Request headers mismatch:\n'
        f'Actual: {http_request.headers}\n'
        f'Expected: {interaction.request.headers}'
    )
    assert http_request.method == interaction.request.method

    # Sanitize the request body, rewrite any fields that vary.
    request_data_copy = copy.deepcopy(http_request.data)
    # Both the request and recorded request must be redacted before comparing
    # so that the comparison is fair.
    if not isinstance(request_data_copy, bytes):
      _redact_request_body(request_data_copy)

    actual_request_body = [request_data_copy]
    expected_request_body = interaction.request.body_segments
    assert _equals_ignore_key_case(actual_request_body, expected_request_body), (
        'Request body mismatch:\n'
        f'Actual: {actual_request_body}\n'
        f'Expected: {expected_request_body}'
    )

  def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse:
    redact_http_request(http_request)

    if self.replay_session is None:
      raise ValueError('No replay session found.')
    interaction = self.replay_session.interactions[self._replay_index]
    # Replay is on the right side of the assert so the diff makes more sense.
    self._match_request(http_request, interaction)
    self._replay_index += 1
    self._sdk_response_index = 0
    errors.APIError.raise_for_response(interaction.response)
    http_response = HttpResponse(
        headers=interaction.response.headers,
        response_stream=[
            json.dumps(segment)
            for segment in interaction.response.body_segments
        ],
        byte_stream=interaction.response.byte_segments,
    )
    if http_response.response_stream == ['{}']:
      http_response.response_stream = [""]
    return http_response

  def _verify_response(self, response_model: BaseModel) -> None:
    if self._mode == 'api':
      return
    if not self.replay_session:
      raise ValueError('No replay session found.')
    # replay_index is advanced in _build_response_from_replay, so we need to -1.
    interaction = self.replay_session.interactions[self._replay_index - 1]
    if self._should_update_replay():
      if isinstance(response_model, list):
        response_model = response_model[0]
      sdk_response_response = getattr(response_model, 'sdk_http_response', None)
      if response_model and (
          sdk_response_response is not None
      ):
        headers = getattr(
            sdk_response_response, 'headers', None
        )
        if headers:
          pop_undeterministic_headers(headers)
      interaction.response.sdk_response_segments.append(
          response_model.model_dump(exclude_none=True)
      )
      return

    if isinstance(response_model, list):
      response_model = response_model[0]
    _debug_print(
        f'response_model: {response_model.model_dump(exclude_none=True)}'
    )
    actual = response_model.model_dump(exclude_none=True, mode='json')
    expected = interaction.response.sdk_response_segments[
        self._sdk_response_index
    ]
    # The sdk_http_response.body has format in the string, need to get rid of
    # the format information before comparing.
    if isinstance(expected, dict):
      if 'sdk_http_response' in expected and isinstance(
          expected['sdk_http_response'], dict
      ):
        if 'body' in expected['sdk_http_response']:
          raw_body = expected['sdk_http_response']['body']
          _debug_print(f'raw_body length: {len(raw_body)}')
          _debug_print(f'raw_body: {raw_body}')
          if isinstance(raw_body, str) and raw_body != '':
            raw_body = json.loads(raw_body)
            raw_body = json.dumps(raw_body)
            expected['sdk_http_response']['body'] = raw_body
    if not self._private:
      assert (
          actual == expected
      ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
    else:
      _debug_print(f'Expected SDK response mismatch:\nActual: {actual}\nExpected: {expected}')
    self._sdk_response_index += 1

  def _request(
      self,
      http_request: HttpRequest,
      http_options: Optional[HttpOptionsOrDict] = None,
      stream: bool = False,
  ) -> HttpResponse:
    self._initialize_replay_session_if_not_loaded()
    if self._should_call_api():
      _debug_print('api mode request: %s' % http_request)
      with _record_on_api_error(self, http_request):
        result = super()._request(http_request, http_options, stream)
      if stream:
        result_segments = []
        for segment in result.segments():
          result_segments.append(json.dumps(segment))
        result = HttpResponse(result.headers, result_segments)
        self._record_interaction(http_request, result)
        # Need to return a RecordedResponse that rebuilds the response
        # segments since the stream has been consumed.
      else:
        self._record_interaction(http_request, result)
      _debug_print('api mode result: %s' % result.json)
      return result
    else:
      return self._build_response_from_replay(http_request)

  async def _async_request(
      self,
      http_request: HttpRequest,
      http_options: Optional[HttpOptionsOrDict] = None,
      stream: bool = False,
  ) -> HttpResponse:
    self._initialize_replay_session_if_not_loaded()
    if self._should_call_api():
      _debug_print('api mode request: %s' % http_request)
      async with _async_record_on_api_error(self, http_request):
        result = await super()._async_request(
            http_request, http_options, stream
        )
      if stream:
        result_segments = []
        async for segment in result.async_segments():
          result_segments.append(json.dumps(segment))
        result = HttpResponse(result.headers, result_segments)
        self._record_interaction(http_request, result)
        # Need to return a RecordedResponse that rebuilds the response
        # segments since the stream has been consumed.
      else:
        self._record_interaction(http_request, result)
      _debug_print('api mode result: %s' % result.json)
      return result
    else:
      return self._build_response_from_replay(http_request)

  def upload_file(
      self,
      file_path: Union[str, io.IOBase],
      upload_url: str,
      upload_size: int,
      *,
      http_options: Optional[HttpOptionsOrDict] = None,
  ) -> HttpResponse:
    if isinstance(file_path, io.IOBase):
      offset = file_path.tell()
      content = file_path.read()
      file_path.seek(offset, os.SEEK_SET)
      request = HttpRequest(
          method='POST',
          url='',
          data={'bytes': base64.b64encode(content).decode('utf-8')},
          headers={}
      )
    else:
      request = HttpRequest(
          method='POST', url='', data={'file_path': file_path}, headers={}
      )
    if self._should_call_api():
      result: Union[str, HttpResponse]
      with _record_on_api_error(self, request):
        result = super().upload_file(
            file_path, upload_url, upload_size, http_options=http_options
        )
      self._record_interaction(request, result)
      return result
    else:
      return self._build_response_from_replay(request)

  async def async_upload_file(
      self,
      file_path: Union[str, io.IOBase],
      upload_url: str,
      upload_size: int,
      *,
      http_options: Optional[HttpOptionsOrDict] = None,
  ) -> HttpResponse:
    if isinstance(file_path, io.IOBase):
      offset = file_path.tell()
      content = file_path.read()
      file_path.seek(offset, os.SEEK_SET)
      request = HttpRequest(
          method='POST',
          url='',
          data={'bytes': base64.b64encode(content).decode('utf-8')},
          headers={},
      )
    else:
      request = HttpRequest(
          method='POST', url='', data={'file_path': file_path}, headers={}
      )
    if self._should_call_api():
      result: HttpResponse
      async with _async_record_on_api_error(self, request):
        result = await super().async_upload_file(
            file_path, upload_url, upload_size, http_options=http_options
        )
      self._record_interaction(request, result)
      return result
    else:
      return self._build_response_from_replay(request)

  def download_file(
      self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
  ) -> Union[HttpResponse, bytes, Any]:
    self._initialize_replay_session_if_not_loaded()
    request = self._build_request(
        'get', path=path, request_dict={}, http_options=http_options
    )
    if self._should_call_api():
      with _record_on_api_error(self, request):
        result = super().download_file(path, http_options=http_options)
      self._record_interaction(request, result)
      return result
    else:
      return self._build_response_from_replay(request).byte_stream[0]

  async def async_download_file(
      self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
  ) -> Any:
    self._initialize_replay_session_if_not_loaded()
    request = self._build_request(
        'get', path=path, request_dict={}, http_options=http_options
    )
    if self._should_call_api():
      async with _async_record_on_api_error(self, request):
        result = await super().async_download_file(
            path, http_options=http_options
        )
      self._record_interaction(request, result)
      return result
    else:
      return self._build_response_from_replay(request).byte_stream[0]
